package com.zzwx.test.dbunit;

import java.io.File;
import java.io.FileInputStream;
import java.sql.Connection;
import java.util.ArrayList;

import org.dbunit.database.DatabaseConnection;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSet;
import org.dbunit.operation.DatabaseOperation;

import com.zzwx.test.dbunit.base.DBUnitBase;

/**
 * Loads test data into the database after backing up the affected tables,
 * writing one XML backup file per table.
 * <p>
 * Relies on members that are not declared in this class and are therefore
 * inherited from {@link DBUnitBase}: {@code getConnection()},
 * {@code closeCon()}, {@code backupData(..)}, {@code recoverData(..)},
 * {@code testDataPath}, {@code file} and {@code files}.
 *
 * @author Roger
 */
public class DBUnitEach extends DBUnitBase {

	/**
	 * Backs up the given tables (one file per table) and then loads the test
	 * data set with a {@code CLEAN_INSERT} operation.
	 * <p>
	 * Failures are printed to standard error and not rethrown. If the backup
	 * step fails, the test data is <em>not</em> loaded, so existing data is
	 * never overwritten without a backup.
	 *
	 * @param fileName
	 *            name of the flat XML test data file, appended to
	 *            {@code testDataPath}
	 * @param tableNames
	 *            names of the tables to back up before loading the test data
	 */
	@SuppressWarnings("deprecation")
	public void setUpBackupEach(String fileName, String... tableNames) {
		try {
			// Back up BEFORE acquiring the connection used for the insert:
			// backupDataEach() calls closeCon() in its finally block, so a
			// connection obtained earlier could already be released by the
			// time CLEAN_INSERT runs.
			backupDataEach(tableNames);

			// JDBC connection wrapped into a DBUnit connection.
			Connection conn = getConnection();
			IDatabaseConnection connection = new DatabaseConnection(conn);

			// Read the test data and make sure the stream is closed again.
			FileInputStream input = new FileInputStream(testDataPath + fileName);
			try {
				IDataSet dataSet = new FlatXmlDataSet(input);
				DatabaseOperation.CLEAN_INSERT.execute(connection, dataSet);
			} finally {
				input.close();
			}
		} catch (Exception e) {
			e.printStackTrace();
		} finally {
			closeCon();
		}
	}

	/**
	 * Backs up each of the given tables into its own file named
	 * {@code <tableName>_back.xml} and records the resulting files in
	 * {@code files} so that {@link #recoverBackupEach()} can restore them.
	 * <p>
	 * When {@code tableNames} is {@code null} or empty nothing is backed up
	 * and {@code files} is left untouched. {@code closeCon()} is always
	 * called before returning.
	 *
	 * @param tableNames
	 *            names of the tables to back up
	 * @throws Exception
	 *             if backing up any table fails; tables after the failing
	 *             one are not processed
	 */
	public void backupDataEach(String... tableNames) throws Exception {
		try {
			if (tableNames != null && tableNames.length > 0) {
				files = new ArrayList<File>();
				for (String tableName : tableNames) {
					super.backupData(tableName, tableName + "_back.xml");
					// "file" is the inherited field; presumably set by
					// backupData() to the file just written - verify
					// against DBUnitBase.
					files.add(file);
				}
			}
		} finally {
			closeCon();
		}
	}

	/**
	 * Restores every backup file recorded by the last
	 * {@link #backupDataEach(String...)} call. Does nothing when no backup
	 * has been recorded.
	 */
	public void recoverBackupEach() {
		if (files == null) {
			return;
		}
		for (File backupFile : files) {
			super.recoverData(backupFile);
		}
	}
}
